#include <cmath>

#include <Rcpp.h>

using namespace Rcpp;

// Vectorised logistic-distribution helpers (same formulas as R's plogis and
// qlogis with location `loc` and scale `scl`), built on Rcpp sugar.
//
// PLOGIS: logistic CDF, 1 / (1 + exp((loc - q) / scl)), elementwise in q.
inline NumericVector PLOGIS(NumericVector q,
                            double loc,
                            double scl)
{
  NumericVector neg_standardized = (loc - q) / scl;
  return 1.0 / (1.0 + exp(neg_standardized));
}
// QLOGIS: logistic quantile function, loc - scl * log(1/p - 1), elementwise
// in p (the inverse of PLOGIS).
inline NumericVector QLOGIS(NumericVector p,
                            double loc,
                            double scl)
{
  NumericVector log_inverse_odds = log((1.0 / p) - 1.0);
  return loc - scl * log_inverse_odds;
}

// Reimplementation of R's which(): the positions of the TRUE entries of x.
//
// By default the positions are 0-based so they can index C++/Rcpp vectors
// directly; pass zero_index = false to get R's 1-based positions.
IntegerVector WHICH(LogicalVector x,
                    bool zero_index = true)
{
    IntegerVector all_positions = seq_along(x);
    IntegerVector hits = all_positions[x];
    if (!zero_index)
        return hits;
    IntegerVector shifted = hits - 1;
    return shifted;
}


double wew_a ( double x )
//****************************************************************************80
//
//  Purpose:
//
//    WEW_A estimates Lambert's W function.
//
//  Discussion:
//
//    For a given X, this routine estimates the solution W of Lambert's
//    equation:
//
//      X = W * EXP ( W )
//
//    An initial guess (a rational function of X for X <= 6.46, LOG(X)
//    otherwise) is refined by exactly two fixed correction steps; there is
//    no convergence test.
//
//    The correction steps evaluate LOG(X) and LOG(W), so they need
//    0 < X < +Inf. The endpoints are returned directly: W(0) = 0 and
//    W(+Inf) = +Inf. Without these checks both would come out NaN
//    (LOG(0) - LOG(0) and Inf - Inf). X = 0 arises in this file when the
//    callers' EXP(z) underflows. Negative or NaN X still yields NaN; the
//    principal branch on [-1/e, 0) is not implemented.
//
//  Modified:
//
//    11 June 2014
//
//  Reference:
//
//    Fred Fritsch, R Shafer, W Crowley,
//    Algorithm 443: Solution of the transcendental equation w e^w = x,
//    Communications of the ACM,
//    February 1973, Volume 16, Number 2, pages 123-124.
//
//  Parameters:
//
//    Input, double X, the argument of W(X).
//
//    Output, double WEW_A, the estimated value of W(X).
//
{
  const double c1 = 4.0 / 3.0;
  const double c2 = 7.0 / 3.0;
  const double c3 = 5.0 / 6.0;
  const double c4 = 2.0 / 3.0;
  double en;
  double f;
  double temp;
  double temp2;
  double wn;
  double y;
  double zn;
//
//  Endpoints that the iteration below cannot evaluate.
//
  if ( x == 0.0 )
  {
    return 0.0;
  }
  if ( std::isinf ( x ) && 0.0 < x )
  {
    return x;
  }
//
//  Initial guess.
//
  f = std::log ( x );

  if ( x <= 6.46 )
  {
    wn = x * ( 1.0 + c1 * x ) / ( 1.0 + x * ( c2 + c3 * x ) );
    zn = f - wn - std::log ( wn );
  }
  else
  {
    wn = f;
    zn = - std::log ( wn );
  }
//
//  Iteration 1.
//
  temp = 1.0 + wn;
  y = 2.0 * temp * ( temp + c4 * zn ) - zn;
  wn = wn * ( 1.0 + zn * y / ( temp * ( y - zn ) ) );
//
//  Iteration 2.
//
  zn = f - wn - std::log ( wn );
  temp = 1.0 + wn;
  temp2 = temp + c4 * zn;
  en = zn * temp2 / ( temp * temp2 - 0.5 * zn );
  wn = wn * ( 1.0 + en );

  return wn;
}

// Vectorised Lambert W: applies wew_a() to each element of x and returns
// the results in a new vector of the same length.
NumericVector LW(NumericVector x)
{
  const int len = x.size();
  NumericVector out(len);
  NumericVector::iterator dst = out.begin();
  for (NumericVector::iterator src = x.begin(); src != x.end(); ++src, ++dst)
    *dst = wew_a(*src);
  return out;
}

// Equilibrium proportions (win probabilities) for a single contest.
//
// ratio[i] is contestant i's cost-effectiveness ratio. For an active set of
// J contestants with T = min(sum of their ratios, xmax), active contestant i
// receives
//   prop[i] = (T - (J - 1) * ratio[i]) / T
// and inactive contestants receive 0. Starting with everyone active, the
// active contestant with the highest ratio is dropped, and the proportions
// recomputed, until no active contestant has a negative proportion. Within
// one pass prop[i] falls as ratio[i] rises, so the highest-ratio contestant
// always holds the smallest proportion.
//
// Returns a vector the same length as ratio. An empty ratio yields an empty
// result.
// [[Rcpp::export]]
NumericVector contest_eq(NumericVector ratio,
                         double xmax)
{
  const int n = ratio.size();
  NumericVector prop(n);

  // With no contestants there is nothing to solve. Without this guard the
  // elimination step below would write to active[0] of an empty vector.
  if (n == 0)
    return prop;

  LogicalVector active(n, true);
  for (;;) {
    // Total cost-effectiveness ratio for the current active set, capped at xmax
    NumericVector ratio_active = ratio[active];
    double total = sum(ratio_active);
    if (total > xmax)
      total = xmax;
    const int n_active = sum(active);

    // Equilibrium proportions for this active set. Track the active
    // contestant with the highest ratio as the next candidate to drop.
    // Start below any real ratio so an already-inactive index is never picked.
    double ratio_max = R_NegInf;
    int which_max = -1;
    bool any_negative = false;
    for (int i = 0; i < n; i++) {
      if (!active[i]) {
        prop[i] = 0;
        continue;
      }
      if (ratio[i] > ratio_max) {
        ratio_max = ratio[i];
        which_max = i;
      }
      prop[i] = (total - (n_active - 1) * ratio[i]) / total;
      if (prop[i] < 0)
        any_negative = true;
    }

    // A negative proportion means there is no equilibrium with this active
    // set. Stop once there is none, or if no contestant could be selected
    // to drop (no progress is possible).
    if (!any_negative || which_max < 0)
      break;
    active[which_max] = false;
  }

  return prop;
}

// Equilibrium contest quantities for every dispute.
//
// For dispute d, the states whose state_level_id equals dispute_level_id[d]
// are solved jointly with contest_eq(ratio, xmax), giving each one an
// equilibrium proportion p. The returned List contains:
//   prob_win_a : per dispute, sum of p over states with side_a == 1, capped
//                at 1 to absorb floating-point error
//   pi_a       : per dispute, sum of p^2 over states with side_a == 1
//   pi_b       : per dispute, sum of p^2 over states with side_a == 0
//   p_indiv    : per state, its p from its dispute (left at 0 if its id
//                matches no dispute)
// [[Rcpp::export]]
List contest_vals(IntegerVector dispute_level_id,
                  IntegerVector state_level_id,
                  IntegerVector side_a,
                  NumericVector ratio,
                  double xmax)
{
  const int n_dispute = dispute_level_id.size();
  const int n_states = state_level_id.size();
  NumericVector prob_win_a(n_dispute);
  NumericVector pi_a_vec(n_dispute);
  NumericVector pi_b_vec(n_dispute);
  NumericVector p_indiv(n_states);

  for (int d = 0; d < n_dispute; d++) {
    // States taking part in this dispute (mask built once, reused below)
    const int id = dispute_level_id[d];
    LogicalVector member = state_level_id == id;

    NumericVector ratio_d = ratio[member];
    NumericVector prop = contest_eq(ratio_d, xmax);
    NumericVector prop_sq = prop * prop;

    // Split members by side
    IntegerVector side_d = side_a[member];
    LogicalVector on_a = side_d == 1;
    LogicalVector on_b = side_d == 0;
    NumericVector prop_a = prop[on_a];
    NumericVector prop_sq_a = prop_sq[on_a];
    NumericVector prop_sq_b = prop_sq[on_b];

    const double p_a = sum(prop_a);
    prob_win_a[d] = (p_a > 1.0) ? 1.0 : p_a;
    pi_a_vec[d] = sum(prop_sq_a);
    pi_b_vec[d] = sum(prop_sq_b);

    p_indiv[member] = prop;
  }

  return List::create(Named("prob_win_a") = prob_win_a,
                      Named("pi_a") = pi_a_vec,
                      Named("pi_b") = pi_b_vec,
                      Named("p_indiv") = p_indiv);
}

// Derivatives of p_win_a, pi_a and pi_b for a single dispute with respect to
// each player's ratio.
//
// Element j of each returned vector is d[quantity]/d[ratio_j], where
//   d_p_a  : quantity = sum of p_i over players with side_a[i] nonzero
//   d_pi_a : quantity = sum of p_i^2 over players with side_a[i] nonzero
//   d_pi_b : quantity = sum of p_i^2 over the remaining players
// Players with p <= 0 are treated as inactive. Their p_i has zero
// derivative, and their ratio is excluded from the total T.
//
// For active i, j among J active players, p_i = 1 - (J - 1) * r_i / T, so
//   dp_i/dr_i = (1 - J) * (T - r_i) / T^2
//   dp_i/dr_j = (J - 1) * r_i / T^2          (j != i)
//
// NOTE(review): T here is the uncapped sum of active ratios, whereas
// contest_eq() caps its total at xmax. If that cap binds, these derivatives
// do not match the probabilities contest_eq() produced. Confirm the cap
// never binds in practice, or thread xmax through.
List contest_derivs(NumericVector ratio,
                    NumericVector p,
                    IntegerVector side_a)
{
  const int n_players = ratio.size();
  LogicalVector is_active = p > 0;
  const int J = sum(is_active);
  NumericVector active_ratios = ratio[is_active];
  const double total = sum(active_ratios);
  const double total_sq = total * total;

  NumericVector d_p_a(n_players, 0.0);
  NumericVector d_pi_a(n_players, 0.0);
  NumericVector d_pi_b(n_players, 0.0);

  // Player i's proportion p_i contributes to its own side's totals.
  // Accumulate dp_i/dr_j (and d(p_i^2)/dr_j) into slot j.
  for (int i = 0; i < n_players; i++) {
    const double own_effect = (1.0 - J) * (total - ratio[i]) / total_sq;
    const double cross_effect = (J - 1.0) * ratio[i] / total_sq;
    const bool on_side_a = side_a[i] != 0;

    for (int j = 0; j < n_players; j++) {
      double dp = 0.0;
      if (is_active[i] && is_active[j])
        dp = (j == i) ? own_effect : cross_effect;
      const double dpi = 2.0 * p[i] * dp;

      if (on_side_a) {
        d_p_a[j] += dp;
        d_pi_a[j] += dpi;
      } else {
        d_pi_b[j] += dpi;
      }
    }
  }

  return List::create(Named("d_p_a") = d_p_a,
                      Named("d_pi_a") = d_pi_a,
                      Named("d_pi_b") = d_pi_b);
}

// Derivatives of p_win_a, pi_a and pi_b with respect to the beta (X) and
// gamma (Z) coefficients, for every dispute.
//
// For dispute n the players are the positions where state_level_id equals
// dispute_level_id[n]. ratio, p_indiv, side_a and the rows of X and Z are
// read at those same positions (assumed aligned with state_level_id).
// contest_derivs() supplies d[quantity]/d[ratio_i], which is chained with
//   d ratio_i / d coef_k = -X[i, k] * ratio_i          for k < ncol(X)
//   d ratio_i / d coef_k =  Z[i, k - ncol(X)] * ratio_i otherwise.
//
// Returns a List of n_disp x (ncol(X) + ncol(Z)) matrices named d_pwin,
// d_pi_a and d_pi_b.
List all_contest_derivs(IntegerVector dispute_level_id,
                        IntegerVector state_level_id,
                        IntegerVector side_a,
                        NumericVector ratio,
                        NumericVector p_indiv,
                        NumericMatrix X,
                        NumericMatrix Z)
{
  const int n_X = X.ncol();
  const int n_coef = n_X + Z.ncol();
  const int n_disp = dispute_level_id.size();

  NumericMatrix d_pwin(n_disp, n_coef);
  NumericMatrix d_pi_a(n_disp, n_coef);
  NumericMatrix d_pi_b(n_disp, n_coef);

  for (int n = 0; n < n_disp; n++) {
    // 0-based positions of this dispute's players
    const int dispute_n = dispute_level_id[n];
    IntegerVector rows = WHICH(state_level_id == dispute_n);
    const int n_players = rows.size();

    NumericVector ratio_n = ratio[rows];
    NumericVector p_n = p_indiv[rows];
    IntegerVector side_a_n = side_a[rows];

    // d[quantity] / d[ratio_i] for each player i in this dispute
    List by_ratio = contest_derivs(ratio_n, p_n, side_a_n);
    NumericVector dpwin_dr = by_ratio["d_p_a"];
    NumericVector dpia_dr = by_ratio["d_pi_a"];
    NumericVector dpib_dr = by_ratio["d_pi_b"];

    // Chain rule: sum over players of d[quantity]/d[ratio_i] * d[ratio_i]/d[coef_k]
    for (int k = 0; k < n_coef; k++) {
      double acc_pwin = 0.0;
      double acc_pi_a = 0.0;
      double acc_pi_b = 0.0;
      for (int i = 0; i < n_players; i++) {
        const int row = rows[i];
        double mult;
        if (k < n_X)
          mult = -1.0 * X(row, k);
        else
          mult = Z(row, k - n_X);
        const double dr_dcoef = mult * ratio_n[i];

        acc_pwin += dpwin_dr[i] * dr_dcoef;
        acc_pi_a += dpia_dr[i] * dr_dcoef;
        acc_pi_b += dpib_dr[i] * dr_dcoef;
      }
      d_pwin(n, k) = acc_pwin;
      d_pi_a(n, k) = acc_pi_a;
      d_pi_b(n, k) = acc_pi_b;
    }
  }

  return List::create(Named("d_pwin") = d_pwin,
                      Named("d_pi_a") = d_pi_a,
                      Named("d_pi_b") = d_pi_b);
}

// Probability of war for each dispute.
//
// For dispute d, side A's type is evaluated at the quantiles p_theta_a via
// theta_a = QLOGIS(p_theta_a, loc_a[d], scl_a[d]). For each draw the offer is
//   x = pi_a + theta_a + scl_b * (1 + W(exp(z))),
//   z = (1 - pi_a - pi_b - loc_b - theta_a) / scl_b - 1,
// with W Lambert's W (LW). The offer is rejected with probability
//   1 - PLOGIS(1 - x - pi_b, loc_b, scl_b).
// The returned Pr(war) is the mean rejection probability over the draws.
// [[Rcpp::export]]
NumericVector prob_war(NumericVector p_theta_a,
                       NumericVector loc_a,
                       NumericVector loc_b,
                       NumericVector scl_a,
                       NumericVector scl_b,
                       NumericVector pi_a,
                       NumericVector pi_b)
{
  const int n_disp = loc_a.size();
  NumericVector pr_war(n_disp);

  for (int d = 0; d < n_disp; d++) {
    NumericVector theta = QLOGIS(p_theta_a, loc_a[d], scl_a[d]);

    // Lambert-W term of the optimal offer
    NumericVector expo = ((1.0 - pi_a[d] - pi_b[d] - loc_b[d] - theta) / scl_b[d]) - 1.0;
    NumericVector exp_expo = exp(expo);
    NumericVector w = LW(exp_expo);

    NumericVector offer = pi_a[d] + theta + scl_b[d] * (1.0 + w);
    NumericVector reject = 1.0 - PLOGIS(1.0 - offer - pi_b[d], loc_b[d], scl_b[d]);
    pr_war[d] = mean(reject);
  }

  return pr_war;
}

// Counterfactual variant of prob_war() that also reports the optimal offer.
//
// The computation per dispute matches prob_war(). This function is kept
// deliberately separate so that refactoring it cannot disturb the
// estimation code path. Returns a List with
//   pr_war : per dispute, mean rejection probability over the draws
//   x_star : per dispute, mean optimal offer x over the draws
//
// [[Rcpp::export]]
List prob_war_and_x(NumericVector p_theta_a,
                    NumericVector loc_a,
                    NumericVector loc_b,
                    NumericVector scl_a,
                    NumericVector scl_b,
                    NumericVector pi_a,
                    NumericVector pi_b)
{
  const int n_disp = loc_a.size();
  NumericVector pr_war(n_disp);
  NumericVector mean_offer(n_disp);

  for (int d = 0; d < n_disp; d++) {
    NumericVector theta = QLOGIS(p_theta_a, loc_a[d], scl_a[d]);

    // Lambert-W term of the optimal offer
    NumericVector expo = ((1.0 - pi_a[d] - pi_b[d] - loc_b[d] - theta) / scl_b[d]) - 1.0;
    NumericVector exp_expo = exp(expo);
    NumericVector w = LW(exp_expo);

    NumericVector offer = pi_a[d] + theta + scl_b[d] * (1.0 + w);
    NumericVector reject = 1.0 - PLOGIS(1.0 - offer - pi_b[d], loc_b[d], scl_b[d]);
    pr_war[d] = mean(reject);
    mean_offer[d] = mean(offer);
  }

  return List::create(Named("pr_war") = pr_war,
                      Named("x_star") = mean_offer);
}

// Derivatives of Pr(war) with respect to the model parameters.
//
// For dispute i, Pr(war) (as in prob_war()) is the draw-average of
//   psi = 1 - PLOGIS(1 - x - pi_b, loc_b, scl_b),
// where theta_a = QLOGIS(p_theta_a, loc_a, scl_a),
//   z = (1 - pi_a - pi_b - loc_b - theta_a) / scl_b - 1,
//   x = pi_a + theta_a + scl_b * (1 + W(exp(z))),
// and W is Lambert's W. Each returned vector holds, per dispute, the
// draw-average of d psi / d parameter, under the names
// d_prwar_d_{loc_a, loc_b, scl_a, scl_b, pi_a, pi_b}.
//
// Numerical notes:
//   * exp(z) * W'(exp(z)) is evaluated as W / (1 + W). This equals
//     W / (exp(z) * (1 + W)) * exp(z), but does not become 0/0 when exp(z)
//     underflows to 0.
//   * The logistic density e^u / (1 + e^u)^2 is symmetric in u, so it is
//     evaluated as e^-|u| / (1 + e^-|u|)^2. This avoids Inf/Inf = NaN when
//     e^u overflows (e.g. small scl_b), where Pr(war) itself stays finite.
List prob_war_derivs(NumericVector p_theta_a,
                     NumericVector loc_a,
                     NumericVector loc_b,
                     NumericVector scl_a,
                     NumericVector scl_b,
                     NumericVector pi_a,
                     NumericVector pi_b) {
  int n_disp = loc_a.size();
  int n_halton = p_theta_a.size();

  // Results storage
  NumericVector d_prwar_d_loc_a(n_disp, 0.0);
  NumericVector d_prwar_d_loc_b(n_disp, 0.0);
  NumericVector d_prwar_d_scl_a(n_disp, 0.0);
  NumericVector d_prwar_d_scl_b(n_disp, 0.0);
  NumericVector d_prwar_d_pi_a(n_disp, 0.0);
  NumericVector d_prwar_d_pi_b(n_disp, 0.0);

  // d theta_A / d scl_a = -log(1/p - 1) does not depend on the dispute
  // (d theta_A / d loc_a is identically 1)
  NumericVector d_theta_a_d_scl_a = -1.0 * log((1.0 / p_theta_a) - 1.0);

  // Loop through disputes
  for (int i = 0; i < n_disp; i++) {
    NumericVector theta_a = QLOGIS(p_theta_a, loc_a[i], scl_a[i]);

    // Derivatives of the optimal offer x
    NumericVector z = ((1.0 - pi_a[i] - pi_b[i] - loc_b[i] - theta_a) / scl_b[i]) - 1.0;
    NumericVector e_z = exp(z);
    NumericVector W_e_z = LW(e_z);
    NumericVector w_share = W_e_z / (1.0 + W_e_z);       // = e^z * W'(e^z)
    NumericVector d_x_d_theta_a = 1.0 - w_share;          // also d x / d pi_a
    NumericVector d_x_d_pi_b = d_x_d_theta_a - 1.0;       // also d x / d loc_b
    NumericVector d_x_d_scl_b = 1.0 + W_e_z - (1.0 + z) * w_share;

    // Logistic density at u = (1 - loc_b - x - pi_b) / scl_b, overflow-safe
    NumericVector x = pi_a[i] + theta_a + scl_b[i] * (1.0 + W_e_z);
    NumericVector lambda(n_halton);
    for (int h = 0; h < n_halton; h++) {
      const double u = (1.0 - loc_b[i] - x[h] - pi_b[i]) / scl_b[i];
      const double e = std::exp(-std::fabs(u));
      lambda[h] = e / ((1.0 + e) * (1.0 + e));
    }

    // Derivatives of the per-draw war probability psi
    NumericVector d_psi_d_theta_a = lambda * d_x_d_theta_a / scl_b[i];
    NumericVector d_psi_d_pi_b = lambda * (1.0 + d_x_d_pi_b) / scl_b[i];
    NumericVector d_psi_d_scl_b = lambda *
      (d_x_d_scl_b * scl_b[i] + 1 - loc_b[i] - x - pi_b[i]) / (scl_b[i] * scl_b[i]);

    // Take and store averages. pi_a enters psi exactly like theta_a, and
    // loc_b exactly like pi_b, so those pairs share derivatives.
    d_prwar_d_loc_a[i] = mean(d_psi_d_theta_a);
    d_prwar_d_loc_b[i] = mean(d_psi_d_pi_b);
    d_prwar_d_scl_a[i] = mean(d_psi_d_theta_a * d_theta_a_d_scl_a);
    d_prwar_d_scl_b[i] = mean(d_psi_d_scl_b);
    d_prwar_d_pi_a[i] = mean(d_psi_d_theta_a);
    d_prwar_d_pi_b[i] = mean(d_psi_d_pi_b);
  }

  return List::create(Named("d_prwar_d_loc_a") = d_prwar_d_loc_a,
                      Named("d_prwar_d_loc_b") = d_prwar_d_loc_b,
                      Named("d_prwar_d_scl_a") = d_prwar_d_scl_a,
                      Named("d_prwar_d_scl_b") = d_prwar_d_scl_b,
                      Named("d_prwar_d_pi_a") = d_prwar_d_pi_a,
                      Named("d_prwar_d_pi_b") = d_prwar_d_pi_b);
}

// Per-dispute log-likelihood contributions.
//
// Pr(A wins) (pwin) comes from the contest among all disputants. The war
// payoffs pi_a / pi_b come from the contest among originators only (entries
// with orig TRUE). With pwar = prob_war(...), dispute n contributes
//   log(pwar * pwin)        if war[n] == 1 and win_a[n] == 1
//   log(pwar * (1 - pwin))  if war[n] == 1 and win_b[n] == 1
//   log(pwar)               if war[n] == 1 with neither winner flagged
//   log(1 - pwar)           otherwise
// [[Rcpp::export]]
NumericVector loglik_backend(IntegerVector dispute_level_id,
                             IntegerVector war,
                             IntegerVector win_a,
                             IntegerVector win_b,
                             IntegerVector state_level_id,
                             IntegerVector side_a,
                             LogicalVector orig,
                             NumericVector p_theta_a,
                             NumericVector ratio,
                             NumericVector loc_a,
                             NumericVector loc_b,
                             NumericVector scl_a,
                             NumericVector scl_b,
                             double xmax)
{
  // Victory probabilities: every disputant takes part
  List contest_all = contest_vals(dispute_level_id, state_level_id, side_a, ratio, xmax);
  NumericVector pwin = contest_all["prob_win_a"];

  // War payoffs: originators only, joiners left out
  List contest_orig = contest_vals(dispute_level_id, state_level_id[orig], side_a[orig], ratio[orig], xmax);
  NumericVector pi_a = contest_orig["pi_a"];
  NumericVector pi_b = contest_orig["pi_b"];

  NumericVector pwar = prob_war(p_theta_a, loc_a, loc_b, scl_a, scl_b, pi_a, pi_b);

  const int n_disp = dispute_level_id.size();
  NumericVector loglik(n_disp);
  for (int n = 0; n < n_disp; n++) {
    double lik;
    if (war[n] == 1) {
      lik = pwar[n];
      if (win_a[n] == 1)
        lik *= pwin[n];
      else if (win_b[n] == 1)
        lik *= 1.0 - pwin[n];
    } else {
      lik = 1.0 - pwar[n];
    }
    loglik[n] = log(lik);
  }

  return loglik;
}

// Per-dispute gradient of the log-likelihood computed by loglik_backend().
//
// Returns an n_disp x K matrix with K = ncol(X) + ncol(Z) + ncol(L_a) +
// ncol(S_a). Its columns are, in order:
//   * the X and Z coefficients (beta, gamma), which move Pr(A wins) via the
//     all-disputant contest and Pr(war) via the originator-only war payoffs;
//   * the location coefficients, with d loc_a / d coef = L_a[n, .] and
//     d loc_b / d coef = L_b[n, .];
//   * the scale coefficients, with d scl / d coef taken as scl * S[n, .]
//     (which corresponds to scl = exp(S coef); confirm against the R caller).
// NOTE(review): L_b and S_b are indexed with L_a's / S_a's column counts;
// presumably they always share those shapes.
// [[Rcpp::export]]
NumericMatrix grad_backend(IntegerVector dispute_level_id,
                           IntegerVector war,
                           IntegerVector win_a,
                           IntegerVector win_b,
                           IntegerVector state_level_id,
                           IntegerVector side_a,
                           LogicalVector orig,
                           NumericVector p_theta_a,
                           NumericVector ratio,
                           NumericVector loc_a,
                           NumericVector loc_b,
                           NumericVector scl_a,
                           NumericVector scl_b,
                           NumericMatrix X,
                           NumericMatrix Z,
                           NumericMatrix X_orig,
                           NumericMatrix Z_orig,
                           NumericMatrix L_a,
                           NumericMatrix L_b,
                           NumericMatrix S_a,
                           NumericMatrix S_b,
                           double xmax)
{
  // Equilibrium contest quantities (same split as loglik_backend):
  // Pr(A wins) from all disputants, war payoffs from originators only
  List contest_all = contest_vals(dispute_level_id, state_level_id, side_a, ratio, xmax);
  NumericVector pwin = contest_all["prob_win_a"];
  NumericVector p_indiv_all = contest_all["p_indiv"];

  List contest_orig = contest_vals(dispute_level_id, state_level_id[orig], side_a[orig], ratio[orig], xmax);
  NumericVector pi_a = contest_orig["pi_a"];
  NumericVector pi_b = contest_orig["pi_b"];
  NumericVector p_indiv_orig = contest_orig["p_indiv"];

  NumericVector pwar = prob_war(p_theta_a, loc_a, loc_b, scl_a, scl_b, pi_a, pi_b);

  // Contest derivatives with respect to beta and gamma
  List dcontest_all = all_contest_derivs(dispute_level_id, state_level_id, side_a, ratio, p_indiv_all, X, Z);
  NumericMatrix d_pwin_d_bg = dcontest_all["d_pwin"];

  List dcontest_orig = all_contest_derivs(dispute_level_id, state_level_id[orig], side_a[orig], ratio[orig], p_indiv_orig, X_orig, Z_orig);
  NumericMatrix d_pi_a_d_bg = dcontest_orig["d_pi_a"];
  NumericMatrix d_pi_b_d_bg = dcontest_orig["d_pi_b"];

  // Pr(war) derivatives with respect to its direct inputs
  List dwar = prob_war_derivs(p_theta_a, loc_a, loc_b, scl_a, scl_b, pi_a, pi_b);
  NumericVector d_pwar_d_loc_a = dwar["d_prwar_d_loc_a"];
  NumericVector d_pwar_d_loc_b = dwar["d_prwar_d_loc_b"];
  NumericVector d_pwar_d_scl_a = dwar["d_prwar_d_scl_a"];
  NumericVector d_pwar_d_scl_b = dwar["d_prwar_d_scl_b"];
  NumericVector d_pwar_d_pi_a = dwar["d_prwar_d_pi_a"];
  NumericVector d_pwar_d_pi_b = dwar["d_prwar_d_pi_b"];

  const int n_disp = dispute_level_id.size();
  const int n_bg = X.ncol() + Z.ncol();
  const int n_L = L_a.ncol();
  const int n_S = S_a.ncol();
  const int n_cf = n_bg + n_L + n_S;
  NumericMatrix grad_ll(n_disp, n_cf);

  for (int n = 0; n < n_disp; n++) {
    for (int k = 0; k < n_cf; k++) {
      // Pr(A wins) depends only on beta/gamma; location/scale leave it fixed
      double d_pwin_d_cf = 0.0;
      double d_pwar_d_cf;

      if (k < n_bg) {
        // beta/gamma reach Pr(war) only through the war payoffs pi_a, pi_b
        d_pwin_d_cf = d_pwin_d_bg(n, k);
        d_pwar_d_cf = d_pwar_d_pi_a[n] * d_pi_a_d_bg(n, k) +
          d_pwar_d_pi_b[n] * d_pi_b_d_bg(n, k);
      } else if (k < n_bg + n_L) {
        // Location coefficients enter loc_a / loc_b linearly
        const int col = k - n_bg;
        d_pwar_d_cf = d_pwar_d_loc_a[n] * L_a(n, col) +
          d_pwar_d_loc_b[n] * L_b(n, col);
      } else {
        // Scale coefficients: d scl / d coef = scl * S
        const int col = k - n_bg - n_L;
        d_pwar_d_cf = d_pwar_d_scl_a[n] * (scl_a[n] * S_a(n, col)) +
          d_pwar_d_scl_b[n] * (scl_b[n] * S_b(n, col));
      }

      // Differentiate the matching log-likelihood case
      double g;
      if (war[n] == 1) {
        g = d_pwar_d_cf / pwar[n];
        if (win_a[n] == 1)
          g += d_pwin_d_cf / pwin[n];
        else if (win_b[n] == 1)
          g -= d_pwin_d_cf / (1.0 - pwin[n]);
      } else {
        g = -1.0 * d_pwar_d_cf / (1.0 - pwar[n]);
      }
      grad_ll(n, k) = g;
    }
  }

  return grad_ll;
}
